from datasets import load_dataset, load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer,  LlamaTokenizer, TrainerCallback, AutoConfig, BitsAndBytesConfig
import argparse
import torch
import os
from copy import deepcopy
import numpy as np
from accelerate import Accelerator
import pprint
import deepspeed
import sys
from rmjpo_trainer_rankedChoice import DPOTrainer
from rmjpo_config_rankedChoice import DPOConfig


class MemoryCleanCallback(TrainerCallback):
    """Trainer callback that releases PyTorch's cached CUDA memory.

    After each training step it asks the CUDA caching allocator to hand unused
    cached blocks back to the driver. This keeps reserved memory lower between
    steps, at the cost of some re-allocation overhead on the next step.
    """

    def on_step_end(self, args, state, control, **kwargs):
        # Only side effect: flush the allocator cache. ``control`` is left
        # untouched and nothing is returned, so the Trainer's flow is unchanged.
        torch.cuda.empty_cache()



# Loss type whose training run also records reference-model dispersion values.
_DISPERSION_LOSS_TYPE = "rankedchoice_rmj_dpo"

# Llama-3-style chat template, installed only when the tokenizer ships without one.
_FALLBACK_CHAT_TEMPLATE = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options for RCPO training.

    Paths that are only used *after* training finishes are checked here, so a
    missing flag fails immediately instead of at the end of a long run.
    """
    parser = argparse.ArgumentParser(description="RCPO Training Script")
    parser.add_argument("--model_name_or_path", type=str)
    parser.add_argument("--data_path", type=str)
    parser.add_argument("--loss_type", type=str)
    parser.add_argument("--per_device_train_batch_size", type=int)
    parser.add_argument("--gradient_accumulation_steps", type=int)
    parser.add_argument("--num_train_epochs", type=int)
    parser.add_argument("--beta", type=float)
    parser.add_argument("--max_length", type=int)
    parser.add_argument("--max_prompt_length", type=int)
    parser.add_argument("--max_completion_length", type=int)
    parser.add_argument("--save_steps", type=int)
    parser.add_argument("--logging_steps", type=int)
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--logging_dir", type=str)
    parser.add_argument("--ranked_dpo_finetuned_model_saved_dir", type=str)
    parser.add_argument("--output_reference_dispersion_local_dir", type=str)
    parser.add_argument("--rankedchoice_length", type=int)
    args = parser.parse_args()

    if args.ranked_dpo_finetuned_model_saved_dir is None:
        parser.error("--ranked_dpo_finetuned_model_saved_dir is required")
    if (
        args.loss_type == _DISPERSION_LOSS_TYPE
        and args.output_reference_dispersion_local_dir is None
    ):
        parser.error(
            "--output_reference_dispersion_local_dir is required when "
            f"--loss_type is {_DISPERSION_LOSS_TYPE}"
        )
    return args


def _load_policy_and_reference(model_name_or_path: str):
    """Load the trainable policy model and a frozen copy used as the DPO reference.

    Returns:
        ``(model, ref_model)``, both in bfloat16.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, torch_dtype=torch.bfloat16
    )
    # The KV cache is not needed for training; HF models also warn when
    # use_cache is combined with gradient checkpointing.
    model.config.use_cache = False

    # Copy *before* enabling gradient checkpointing: the reference model never
    # back-propagates, so activation recomputation is pure overhead for it.
    ref_model = deepcopy(model)
    ref_model.requires_grad_(False)
    # Reference log-probs should be computed without dropout.
    ref_model.eval()

    model.gradient_checkpointing_enable()
    return model, ref_model


def _load_tokenizer(model_name_or_path: str):
    """Load the tokenizer and fill in a pad token / chat template if missing."""
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = _FALLBACK_CHAT_TEMPLATE
    return tokenizer


def _save_reference_dispersion(trainer, path: str) -> None:
    """Concatenate the trainer's collected reference-dispersion tensors and save them with ``np.save``.

    Nothing is written (a warning goes to stderr) when no values were collected.
    """
    tensors = trainer.all_reference_dispersion
    if len(tensors) == 0:
        print(
            f"warning: no reference dispersion values were collected; nothing saved to {path}",
            file=sys.stderr,
        )
        return
    # The tensors may be on the GPU, may still track gradients, or may be
    # bfloat16 (the model is loaded in bf16). ``.numpy()`` needs a detached
    # CPU tensor of a numpy-supported dtype.
    values = torch.cat([t.detach().float().cpu() for t in tensors]).numpy()
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.save(path, values)


def main():
    """Fine-tune a causal LM with the ranked-choice DPO trainer and save the results.

    Saves the fine-tuned model and tokenizer to
    ``--ranked_dpo_finetuned_model_saved_dir``. When the trainer reports
    ``output_reference_dispersion``, it also saves the collected reference
    dispersion values to ``--output_reference_dispersion_local_dir`` with
    ``np.save``. NOTE(review): despite its name, this argument is used as a
    file path, not a directory.
    """
    args = _parse_args()

    model, ref_model = _load_policy_and_reference(args.model_name_or_path)
    tokenizer = _load_tokenizer(args.model_name_or_path)
    processed_dataset = load_from_disk(args.data_path)

    # Dispersion output is requested from the trainer only for the RMJ loss.
    output_reference_dispersion_flag = args.loss_type == _DISPERSION_LOSS_TYPE

    dpo_args = DPOConfig(
        loss_type=args.loss_type,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        output_reference_dispersion=output_reference_dispersion_flag,
        beta=args.beta,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        bf16=True,
        save_strategy="steps",
        save_steps=args.save_steps,
        output_dir=args.output_dir,
        gradient_checkpointing=True,
        logging_steps=args.logging_steps,
        logging_dir=args.logging_dir,
        lr_scheduler_type="cosine",
        warmup_ratio=0.2,
        # Must stay False: presumably the custom trainer needs dataset columns
        # that the base Trainer would otherwise drop as unused.
        remove_unused_columns=False,
        rankedchoice_length=args.rankedchoice_length,
        optim="adamw_torch",
        max_grad_norm=1.0,
    )

    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=dpo_args,
        train_dataset=processed_dataset,
        processing_class=tokenizer,
    )
    trainer.add_callback(MemoryCleanCallback())

    trainer.train()

    # Trainer.save_model coordinates ranks itself (e.g. gathering sharded
    # DeepSpeed weights), so it is called on every process.
    trainer.save_model(args.ranked_dpo_finetuned_model_saved_dir)

    # Plain file writes happen once, so concurrent ranks do not clobber the
    # same files.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(args.ranked_dpo_finetuned_model_saved_dir)
        if trainer.output_reference_dispersion:
            # NOTE(review): each rank presumably collects only its own batches;
            # gather across ranks first if the full set is needed.
            _save_reference_dispersion(
                trainer, args.output_reference_dispersion_local_dir
            )


if __name__ == "__main__":
    # Run training only when this file is executed as a script, not when imported.
    main()


